import argparse
import inspect
import json
import os
from typing import Literal

import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LinearRegression
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm

from _models.model import get_embedding_func_batched
from _datasets.data import DatasetConfig
from utils.transform_utils import *
from utils.string_utils import *
from utils.metrics import *


class HumanPrefExperimentConfig:
    """Experiment comparing text-similarity metrics with human judgments.

    Loads the ``openai_summarize_<mode>`` dataset, embeds the original text and
    its summaries, computes several similarity metrics between the original and
    each summary, fits an ensemble over those metrics, and writes the per-row
    data (pickle) and aggregated results (JSON) under ``data/<model_name>/``.

    In ``"scores"`` mode each row has one ``summary`` column plus the human
    ratings ``overall``/``accuracy``/``coverage``/``coherence``; results are
    Pearson correlations mapped to [0, 1] via ``0.5 * (r + 1)``.
    In ``"comparisons"`` mode each row has ``summary1``/``summary2`` and a human
    ``choice``; results are classification scores for predicting that choice.
    """

    def __init__(
        self,
        mode: Literal["scores", "comparisons"],
        num_examples: int,
        model_name: str = "BAAI/bge-small-en-v1.5",
        max_length: int = 8192,
    ):
        """Load the dataset, build the embedding function and the output dir.

        Args:
            mode: ``"scores"`` or ``"comparisons"``; selects the dataset variant.
            num_examples: Forwarded to ``DatasetConfig``; presumably the number
                of examples to load (TODO confirm).
            model_name: Passed to ``get_embedding_func_batched``; also names the
                output directory (with ``/`` replaced by ``_``).
            max_length: Forwarded to ``DatasetConfig.get_dataset``.

        Raises:
            AssertionError: If ``mode`` is not one of the supported values.
        """
        assert mode in ["scores", "comparisons"], "Mode is 'scores' or 'comparisons'."

        self.mode = mode
        self.model_name = model_name
        dataset_name = "openai_summarize_" + mode
        self.dataset_config = DatasetConfig(dataset_name, num_examples)
        # NOTE(review): meaning of the positional ``True`` flag is defined by
        # DatasetConfig.get_dataset — confirm there.
        self.dataset = self.dataset_config.get_dataset(True, max_length)
        print(f"Dataset {dataset_name} loaded.")

        if mode == "scores":
            self.rel_columns = ["original", "summary"]
        else:
            self.rel_columns = ["original", "summary1", "summary2"]

        self.embedding_func = get_embedding_func_batched(model_name)
        self.similarity_data = pd.DataFrame(self.dataset)
        self.results = {}

        # Replace '/' with '_' so e.g. "BAAI/bge-small" does not create a
        # nested subdirectory.
        self.model_data_path = os.path.join("data", self.model_name.replace("/", "_"))
        os.makedirs(self.model_data_path, exist_ok=True)

    def run(self):
        """Run the full pipeline and persist its outputs.

        Steps: embeddings -> similarity metrics -> ensemble fit -> results.
        Writes ``<model_data_path>/<dataset name>.pkl`` (the per-row DataFrame)
        and ``<model_data_path>/<dataset name>.json`` (the results dict, with
        every value cast to ``float``).
        """
        self.generate_embeddings(
            embedding_func=self.embedding_func,
            model_name=self.model_name,
            use_gpu=True,
        )
        print("Generated embeddings.")

        self.calculate_similarities()
        print("Calculated similarities.")

        self.fit_ensembling()
        print("Fitted ensembling.")

        self.get_results()
        print("Got results.")

        # Save the per-row similarity data as a pickle in the model directory.
        data_file_path = os.path.join(
            self.model_data_path, f"{self.dataset_config.name}.pkl"
        )
        self.similarity_data.to_pickle(data_file_path)
        print(f"Saved data to {data_file_path}.")

        # Cast values (e.g. NumPy floats) to built-in float so json can encode
        # them. Done before opening the file so a conversion failure does not
        # truncate an existing results file.
        self.results = {
            k1: {k2: float(v2) for k2, v2 in v1.items()}
            for k1, v1 in self.results.items()
        }
        results_file_path = os.path.join(
            self.model_data_path, f"{self.dataset_config.name}.json"
        )
        with open(results_file_path, "w", encoding="utf-8") as f:
            json.dump(self.results, f)
        print(f"Saved results to {results_file_path}.")

    def generate_embeddings(self, embedding_func, **kwargs):
        """Embed each relevant text column into an ``embeddings_<column>`` column.

        Only non-null texts are embedded; the vectors are aligned back to their
        original row index, so rows with missing text get NaN instead of the
        assignment failing on a length mismatch.

        Args:
            embedding_func: Batched embedding callable accepting ``prompts`` and
                ``pbar`` plus backend-specific keyword arguments.
            **kwargs: Backend options, expected to contain ``model_name`` and
                ``use_gpu``. When the embedding function's source does not
                mention "huggingface", ``model_name`` is passed as ``model``
                and ``use_gpu`` is dropped.
        """
        # Non-Hugging Face backends (detected by source text) take ``model``
        # instead of ``model_name`` and have no ``use_gpu`` option.
        source_code = inspect.getsource(embedding_func)
        if "huggingface" not in source_code:
            kwargs["model"] = kwargs.pop("model_name")
            kwargs.pop("use_gpu", None)

        for column in self.rel_columns:
            if column not in self.similarity_data:
                print(f"Warning: Column {column} does not exist in the DataFrame")
                continue
            texts = self.similarity_data[column].dropna()
            embeds = embedding_func(prompts=texts.tolist(), pbar=False, **kwargs)
            if not isinstance(embeds, list):
                embeds = embeds.tolist()
            # Align on the surviving index: dropna() may have removed rows.
            self.similarity_data[f"embeddings_{column}"] = pd.Series(
                embeds, index=texts.index
            )

    def calculate_similarities(self):
        """Add per-metric similarity columns between ``original`` and each summary.

        For every non-original column ``c`` this writes ``cosine_similarity_c``
        (from the embeddings, so ``generate_embeddings`` must have run) and the
        text-based ``jaccard``/``levenshtein``/``bm25``/``rouge`` similarities.
        When a ``choice`` column is present (comparison data), it also writes
        ``<metric>_similarity_choice``: 0 when summary1 scores strictly higher
        than summary2, else 1 (ties go to 1) — presumably matching the
        dataset's ``choice`` encoding (TODO confirm).
        """
        original = self.similarity_data["original"]
        embeddings_original = self.similarity_data["embeddings_original"]
        text_metrics = (
            ("jaccard", jaccard_similarity),
            ("levenshtein", levenshtein_ratio),
            ("bm25", bm25_score),
            ("rouge", rouge_score),
        )

        for column in self.rel_columns:
            if column == "original":
                continue
            print(f'\tCalculating similarity for column "{column}"')
            self.similarity_data[f"cosine_similarity_{column}"] = cosine_similarity(
                embeddings_original, self.similarity_data[f"embeddings_{column}"]
            )
            print("\t\tCalculated cosine similarity")
            for name, similarity_fn in text_metrics:
                self.similarity_data[f"{name}_similarity_{column}"] = similarity_fn(
                    original, self.similarity_data[column]
                )
                print(f"\t\tCalculated {name} similarity")

        if "choice" in self.similarity_data.columns:
            for metric in metrics:
                prefix = f"{metric}_similarity"
                prefers_first = (
                    self.similarity_data[f"{prefix}_summary1"]
                    > self.similarity_data[f"{prefix}_summary2"]
                )
                self.similarity_data[f"{prefix}_choice"] = np.where(prefers_first, 0, 1)

    @staticmethod
    def _mean_or_nan(values):
        """Return ``np.mean(values)``, or NaN for an empty list (no warning)."""
        return np.mean(values) if len(values) else float("nan")

    def fit_ensembling(self, n_splits=1000):
        """Fit an ensemble over the per-metric similarity columns.

        For each of ``n_splits`` random 80/20 train/test splits (seeded by the
        split number):

        * ``"scores"``: a ``LinearRegression`` without intercept predicts
          ``overall``; the Pearson correlation between the prediction and each
          human rating is measured on the test rows (splits with any NaN
          correlation are skipped). The averages are mapped to [0, 1] with
          ``0.5 * (r + 1)`` — the same mapping ``get_results`` applies to the
          individual metrics. The last split's predictions for all rows remain
          in ``ensembled_similarity_summary``.
        * ``"comparisons"``: a ``RandomForestClassifier`` predicts ``choice``;
          accuracy/precision/recall/F1 on the test rows are averaged.

        Stores the averages in ``self.results["ensembled_similarity"]`` and the
        last fitted model in ``self.ensemble``.

        Args:
            n_splits: Number of random train/test splits to average over.
        """
        if self.mode == "scores":
            X_cols = [metric + "_similarity_summary" for metric in metrics]
            y_col = "overall"
        else:
            X_cols = [metric + "_similarity_summary1" for metric in metrics]
            X_cols += [metric + "_similarity_summary2" for metric in metrics]
            y_col = "choice"

        X = self.similarity_data[X_cols]
        y = self.similarity_data[y_col]

        if self.mode == "scores":
            targets = ["overall", "accuracy", "coverage", "coherence"]
            pred_col = "ensembled_similarity_summary"
            collected = {target: [] for target in targets}
            for i in tqdm(range(n_splits), desc="Ensembling"):
                self.ensemble = LinearRegression(fit_intercept=False)
                X_train, _, y_train, y_test = train_test_split(
                    X, y, test_size=0.2, random_state=i
                )
                self.ensemble.fit(X_train, y_train)
                self.similarity_data[pred_col] = self.ensemble.predict(X)
                corrs = self.similarity_data.loc[
                    y_test.index, [pred_col] + targets
                ].corr()[pred_col]
                # Constant columns in a small test split give NaN correlations.
                if corrs[targets].isna().any():
                    continue
                for target in targets:
                    collected[target].append(corrs[target])

            self.results["ensembled_similarity"] = {
                target: 0.5 * (self._mean_or_nan(values) + 1)
                for target, values in collected.items()
            }
        else:
            accuracies, precisions, recalls, f1s = [], [], [], []
            for i in tqdm(range(n_splits), desc="Ensembling"):
                self.ensemble = RandomForestClassifier(random_state=i)
                X_train, X_test, y_train, y_test = train_test_split(
                    X, y, test_size=0.2, random_state=i
                )
                self.ensemble.fit(X_train, y_train)
                y_test_pred = self.ensemble.predict(X_test)
                accuracies.append(accuracy_score(y_test, y_test_pred))
                precisions.append(precision_score(y_test, y_test_pred, zero_division=1))
                recalls.append(recall_score(y_test, y_test_pred, zero_division=1))
                f1s.append(f1_score(y_test, y_test_pred, zero_division=1))
            self.results["ensembled_similarity"] = {
                "accuracy": self._mean_or_nan(accuracies),
                "precision": self._mean_or_nan(precisions),
                "recall": self._mean_or_nan(recalls),
                "f1": self._mean_or_nan(f1s),
            }

    def get_results(self):
        """Score every similarity metric against the human judgments.

        * ``"scores"``: Pearson correlation of ``<metric>_similarity_summary``
          with overall/accuracy/coverage/coherence over all rows, mapped to
          [0, 1] via ``0.5 * (r + 1)``.
        * ``"comparisons"``: accuracy/precision/recall/F1 of
          ``<metric>_similarity_choice`` against ``choice``.

        The ensemble is scored here only if an ``ensembled`` column exists and
        ``fit_ensembling`` has not already stored its held-out averages —
        otherwise those averages would be replaced by an in-sample estimate
        from a single split's model.

        Returns:
            ``self.results``, updated in place.
        """
        has_ensembled_column = any(
            "ensembled" in column for column in self.similarity_data.columns
        )
        curr_metrics = list(metrics)
        if has_ensembled_column and "ensembled_similarity" not in self.results:
            curr_metrics.append("ensembled")

        if self.mode == "scores":
            targets = ["overall", "accuracy", "coverage", "coherence"]
            for metric in curr_metrics:
                column = f"{metric}_similarity_summary"
                corrs = self.similarity_data[[column] + targets].corr()[column]
                corrs = 0.5 * (corrs + 1)
                self.results[f"{metric}_similarity"] = {
                    target: corrs[target] for target in targets
                }
        else:
            true_choice = self.similarity_data["choice"]
            for metric in curr_metrics:
                predicted = self.similarity_data[f"{metric}_similarity_choice"]
                self.results[f"{metric}_similarity"] = {
                    "accuracy": accuracy_score(true_choice, predicted),
                    "precision": precision_score(true_choice, predicted),
                    "recall": recall_score(true_choice, predicted),
                    "f1": f1_score(true_choice, predicted),
                }
        return self.results


def main(
    mode="scores",
    num_examples=5,
    model_name="embed-english-v3.0",
    max_length=8192,
):
    """Build a ``HumanPrefExperimentConfig`` from the given settings and run it.

    Args:
        mode: ``"scores"`` or ``"comparisons"``.
        num_examples: Forwarded to the experiment's dataset configuration.
        model_name: Embedding model name.
        max_length: Forwarded to the experiment's dataset loading.
    """
    experiment = HumanPrefExperimentConfig(
        mode=mode,
        num_examples=num_examples,
        model_name=model_name,
        max_length=max_length,
    )
    experiment.run()


if __name__ == "__main__":
    # Command-line entry point. ``choices`` rejects an invalid --mode with a
    # clear usage error up front, instead of relying on the constructor's
    # assert (which is stripped when Python runs with -O).
    parser = argparse.ArgumentParser(
        description="Evaluate similarity metrics against human preference data."
    )
    parser.add_argument(
        "--mode",
        type=str,
        default="scores",
        choices=["scores", "comparisons"],
        help="Dataset variant: per-summary scores or pairwise comparisons.",
    )
    parser.add_argument(
        "--num_examples",
        type=int,
        default=5,
        help="Forwarded to the dataset configuration.",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="embed-english-v3.0",
        help="Embedding model name.",
    )
    parser.add_argument(
        "--max_length",
        type=int,
        default=8192,
        help="Forwarded to dataset loading.",
    )
    args = parser.parse_args()

    main(
        args.mode,
        args.num_examples,
        args.model_name,
        args.max_length,
    )
